import os
import numpy as np
import torch
import logging
import torch.nn as nn
import torchvision.models as models

from .cnn import *
from .resnet import *

def build_model(args):
    """Construct the backbone network that matches ``args.dataset``.

    Supported datasets (matched case-insensitively):

    * ``mnist``    -> ``CNN(dataset='MNIST', ...)``
    * ``cifar10``  -> ``resnet18(num_classes=10, ...)``
    * ``cifar100`` -> ``resnet34(num_classes=100, ...)``

    Args:
        args: Config object exposing ``dataset``, ``mode``, ``weight``,
            ``last_relu`` and ``need_linear`` attributes. The last four are
            forwarded unchanged as keyword arguments to the model constructor.

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a supported dataset.
    """
    dataset = str(args.dataset).lower()
    if dataset not in ('mnist', 'cifar10', 'cifar100'):
        # Include the offending value so misconfigured runs are easy to diagnose.
        raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")

    # Options shared by every backbone. They are read only after the dataset
    # check, so an unsupported dataset fails with the error above rather than
    # an AttributeError on a missing option.
    common = dict(
        mode=args.mode,
        weight=args.weight,
        last_relu=args.last_relu,
        need_linear=args.need_linear,
    )

    if dataset == 'mnist':
        return CNN(dataset='MNIST', **common)
    if dataset == 'cifar10':
        return resnet18(num_classes=10, **common)
    return resnet34(num_classes=100, **common)